{# ----------------------------------------------------------------------------
 # SymForce - Copyright 2025, Skydio, Inc.
 # This source code is under the Apache 2.0 license found in the LICENSE file.
 # ---------------------------------------------------------------------------- #}
#include "caspar_mappings_pybinding.h"
#include "pybind_array_tools.h"
#include "shared_indices_pybinding.h"
#include "sort_indices_pybinding.h"

{% for kernel in caslib.kernels %}
{% if kernel.expose_to_python %}
#include "kernel_{{kernel.name}}.h"
{% endif %}
{% endfor %}

{% if solver is not none %}
#include "solver_pybinding.h"
#include "solver_params_pybinding.h"
{% endif %}

{% if caslib.kernels %}
// File-local wrappers: one `<kernel>_pybinding` function is emitted for every kernel in
// `caslib.kernels` whose `expose_to_python` flag is set. They are referenced only by the
// module definition in this file, hence the anonymous namespace.
namespace {

using namespace caspar;

{% for kernel in caslib.kernels %}
{% if kernel.expose_to_python %}
// Python-facing wrapper around the generated `{{kernel.name}}` kernel.
//
// Parameters: the (name, type) pairs of each accessor's `py_sig`, in accessor order,
// followed by `problem_size`, which is forwarded unchanged to the kernel.
// Returns nothing.
void {{kernel.name}}_pybinding(
    {% for accessor in kernel.accessors %}
    {% for name, typ in accessor.py_sig.items() %}
    {{typ}} {{name}},
    {% endfor %}
    {% endfor %}
    size_t problem_size) {
  // Run AssertDeviceMemory on every Python-supplied argument before calling the kernel.
  // NOTE(review): presumably this rejects arrays that are not in device memory; confirm
  // against its declaration (likely pybind_array_tools.h).
  {% for accessor in kernel.accessors %}
  {% for name in accessor.py_sig %}
  AssertDeviceMemory({{name}});
  {% endfor %}
  {% endfor %}

  // Call the kernel with each accessor's `py_args` expressions (in accessor order),
  // followed by the problem size.
  {{kernel.name}}(
      {% for accessor in kernel.accessors %}
      {% for arg in accessor.py_args %}
      {{arg}},
      {% endfor %}
      {% endfor %}
      problem_size
  );
}
{% endif %}

{% endfor %}
}  // namespace

{% endif %}  {# caslib.kernels #}

// Entry point of the Python extension module; the module is named after `caslib.name`.
// Registers, in order: the `shared_indices` function, one function per kernel exposed to
// Python, the casmappings bindings, and (only when a solver is configured) the solver and
// solver-params bindings.
PYBIND11_MODULE({{caslib.name}}, module) {
  module.doc() = "Module containing bindings for a generated caspar library. See the generated pyi file for more details.";
  module.def("shared_indices", &caspar::shared_indices_pybinding);
  // Expose each Python-visible kernel under its own name, bound to its `_pybinding` wrapper.
  {% for kernel in caslib.kernels %}
  {% if kernel.expose_to_python %}
  module.def("{{kernel.name}}", &{{kernel.name}}_pybinding);
  {% endif %}
  {% endfor %}

  caspar::add_casmappings_pybindings(module);
  // Solver bindings are emitted only when the template is rendered with a non-none `solver`.
  {% if solver is not none %}
  caspar::add_solver_pybinding(module);
  caspar::add_solver_params_pybinding(module);
  {% endif %}
}
